import numpy as np
import torch
import torch.nn as nn
from utils.utils_loss import sigmoid_loss, hinge_loss, logistic_loss, ramp_loss, unhinged_loss, exp_loss
import torch.nn.functional as F
import datetime, os, time
from utils.utils_algo import update_ema, exp_rampup

def BinaryBiased(model, uupr_pos_train_loader, uupr_neg_train_loader, test_loader, args, loss_fn, device):
    """Biased baseline: fit a binary classifier as if both loaders held clean labels.

    Every sample drawn from ``uupr_pos_train_loader`` is scored against the
    positive side of ``loss_fn`` and every sample from ``uupr_neg_train_loader``
    against the negative side; labels yielded by the loaders are ignored. The
    two loaders are consumed in lockstep with ``zip``, so an epoch ends when the
    shorter one is exhausted.

    Args:
        model: network whose output column 0 is used as the decision score.
        uupr_pos_train_loader: iterable of ``(x, y)`` batches treated as positive.
        uupr_neg_train_loader: iterable of ``(x, y)`` batches treated as negative.
        test_loader: evaluation batches passed to ``accuracy_check``.
        args: must provide ``lr`` and ``wd`` (Adam settings) and ``ep`` (epochs).
        loss_fn: margin loss called as ``loss_fn(score)`` for positives and
            ``loss_fn(-score)`` for negatives; presumably returns a per-sample
            tensor (it is reduced with ``.mean()``).
        device: device the inputs are moved to.

    Returns:
        Mean of the test accuracies recorded in the last 10 epochs (every epoch
        when ``args.ep`` <= 10).
    """
    accuracy = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy:', accuracy)
    recent_accuracies = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    first_recorded_epoch = args.ep - 10

    for epoch in range(args.ep):
        model.train()
        for (x_pos, _), (x_neg, _) in zip(uupr_pos_train_loader, uupr_neg_train_loader):
            x_pos = x_pos.to(device)
            x_neg = x_neg.to(device)
            optimizer.zero_grad()
            score_pos = model(x_pos)[:, 0]
            score_neg = model(x_neg)[:, 0]
            # Positives should score high, negatives low.
            loss_pos = loss_fn(score_pos).mean()
            loss_neg = loss_fn(-score_neg).mean()
            train_loss = loss_pos + loss_neg
            train_loss.backward()
            optimizer.step()

        model.eval()
        accuracy = accuracy_check(loader=test_loader, model=model, device=device)
        print('#epoch', epoch + 1, ': train_loss', train_loss.item(), 'test_accuracy', accuracy)
        if epoch >= first_recorded_epoch:
            recent_accuracies.append(accuracy)
    return np.mean(recent_accuracies)

def PcompUnbiased(model, given_train_loader, test_loader, args, loss_fn, device):
    """Train with the unbiased Pcomp risk estimator (no non-negativity correction).

    For each mini-batch the minimised risk is

        mean_{y=1}[ l(f) - prior * l(-f) ] + mean_{y=-1}[ l(-f) - (1-prior) * l(f) ]

    with ``l = loss_fn``; a term is left at 0.0 when its class does not appear
    in the batch.

    Args:
        model: network whose output column 0 is used as the decision score.
        given_train_loader: iterable of ``(X, y)`` batches with ``y`` in {1, -1}.
        test_loader: evaluation batches passed to ``accuracy_check``.
        args: must provide ``lr``, ``wd``, ``ep`` and the class prior ``prior``.
        loss_fn: margin loss; presumably returns a per-sample tensor.
        device: device the batches are moved to.

    Returns:
        Mean test accuracy over the last 10 epochs (all epochs if ``args.ep`` <= 10).
    """
    accuracy = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', accuracy)
    history = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior
    neg_prior = 1 - prior

    for epoch in range(args.ep):
        model.train()
        for inputs, targets in given_train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            scores = model(inputs)[:, 0]
            pos_mask = targets == 1
            neg_mask = targets == -1

            pos_risk = 0.0
            if pos_mask.any():
                s_pos = scores[pos_mask]
                pos_risk = (loss_fn(s_pos) - prior * loss_fn(-s_pos)).mean()
            neg_risk = 0.0
            if neg_mask.any():
                s_neg = scores[neg_mask]
                neg_risk = (loss_fn(-s_neg) - neg_prior * loss_fn(s_neg)).mean()

            train_loss = pos_risk + neg_risk
            train_loss.backward()
            optimizer.step()

        model.eval()
        accuracy = accuracy_check(loader=test_loader, model=model, device=device)
        print('#epoch', epoch + 1, ': train_loss', train_loss.item(), 'test_accuracy', accuracy)
        if args.ep - epoch <= 10:
            history.append(accuracy)
    return np.mean(history)

def RankPruning(model, given_train_loader, test_loader, args, loss_fn, device):
    """Train with Rank Pruning, treating the Pcomp labels as class-conditional noisy labels.

    Within each mini-batch, given positives (y == 1) are ranked by ``loss_fn(f)``
    and given negatives (y == -1) by ``loss_fn(-f)``. Only the lowest-loss
    fraction ``1 - inverse noise rate`` of each group is kept, and the kept mean
    loss is scaled by ``1 / (1 - noise rate)``.

    Args:
        model: network whose output column 0 is used as the decision score.
        given_train_loader: iterable of ``(X, y)`` batches with ``y`` in {1, -1}.
        test_loader: evaluation batches passed to ``accuracy_check``.
        args: must provide ``lr``, ``wd``, ``ep`` and the class prior ``prior``.
        loss_fn: margin loss; presumably returns a 1-D per-sample tensor (it is
            ranked per sample).
        device: device the batches are moved to.

    Returns:
        Mean test accuracy over the last 10 epochs (all epochs if ``args.ep`` <= 10).
    """
    test_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', test_acc)
    test_acc_list = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior

    # Noise model induced by Pcomp data (labels are +1 / -1; 1-prior*(1-prior) = 1-prior+prior^2).
    inverse_pos_noise_rate = ((1-prior)*(1-prior)) / (1-prior*(1-prior))  # p(y=-1 | given y=1)
    inverse_neg_noise_rate = (prior*prior) / (1-prior*(1-prior))  # p(y=1 | given y=-1)
    pos_noise_rate = 0.5 * prior / (1-prior*(1-prior))  # p(given y=-1 | y=1) = 1/2*prior/(1-prior*(1-prior))
    neg_noise_rate = 0.5 * (1-prior) / (1-prior*(1-prior))  # p(given y=1 | y=-1) = 1/2*(1-prior)/(1-prior*(1-prior))

    def _pruned_mean(losses, keep_ratio):
        # Mean of the `keep_ratio` fraction of smallest per-sample losses. At
        # least one sample is kept: int() can round the count down to 0 on small
        # (e.g. trailing) batches, and the mean of an empty tensor is NaN, which
        # would propagate into the model parameters.
        num_keep = max(1, int(keep_ratio * losses.shape[0]))
        keep_index = torch.argsort(losses.detach())[:num_keep]  # ascending loss
        return losses[keep_index].mean()

    for epoch in range(args.ep):
        model.train()
        for (X, y) in given_train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(X)[:, 0]
            pos_index, neg_index = (y == 1), (y == -1)
            pos_train_loss, neg_train_loss = 0.0, 0.0
            if pos_index.any():
                pos_train_loss = 1/(1-pos_noise_rate) * _pruned_mean(
                    loss_fn(outputs[pos_index]), 1-inverse_pos_noise_rate)
            if neg_index.any():
                neg_train_loss = 1/(1-neg_noise_rate) * _pruned_mean(
                    loss_fn(-outputs[neg_index]), 1-inverse_neg_noise_rate)
            train_loss = pos_train_loss + neg_train_loss
            train_loss.backward()
            optimizer.step()
        model.eval()
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)
        print('#epoch', epoch+1, ': train_loss', train_loss.data.item(), 'test_accuracy', test_acc)
        if epoch >= (args.ep-10):
            test_acc_list.append(test_acc)
    return np.mean(test_acc_list)

def PcompTeacher(model, ema_model, given_train_loader, test_loader, args, loss_fn, device):
    """Train with Rank Pruning plus a teacher-consistency (EMA model) term.

    Supervised part: within each mini-batch, given positives (y == 1) ranked by
    ``loss_fn(f)`` and given negatives (y == -1) ranked by ``loss_fn(-f)`` are
    pruned to their lowest-loss fraction ``1 - inverse noise rate``, and each
    kept mean is scaled by ``1 / (1 - noise rate)``.

    Consistency part: ``ema_model`` scores the same batch without gradients,
    and ``args.ema_weight * MSE(model scores, ema_model scores)`` is added to
    the loss. ``ema_model`` is refreshed every step through ``update_ema``.
    Only ``model`` is evaluated on ``test_loader``.

    Args:
        model: student network; output column 0 is the decision score.
        ema_model: teacher network with the same output layout.
        given_train_loader: iterable of ``(X, y)`` batches with ``y`` in {1, -1}.
        test_loader: evaluation batches passed to ``accuracy_check``.
        args: must provide ``lr``, ``wd``, ``ep``, ``prior``, ``ema_weight``
            (consistency weight) and ``ema_alpha`` (decay passed to ``update_ema``).
        loss_fn: margin loss; presumably returns a 1-D per-sample tensor.
        device: device the batches are moved to.

    Returns:
        Mean test accuracy over the last 10 epochs (all epochs if ``args.ep`` <= 10).
    """
    test_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', test_acc)
    test_acc_list = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior
    global_step = 0
    cons_weight = args.ema_weight
    ema_decay = args.ema_alpha

    # Noise model induced by Pcomp data (labels are +1 / -1; 1-prior*(1-prior) = 1-prior+prior^2).
    inverse_pos_noise_rate = ((1-prior)*(1-prior)) / (1-prior*(1-prior))  # p(y=-1 | given y=1)
    inverse_neg_noise_rate = (prior*prior) / (1-prior*(1-prior))  # p(y=1 | given y=-1)
    pos_noise_rate = 0.5 * prior / (1-prior*(1-prior))  # p(given y=-1 | y=1) = 1/2*prior/(1-prior*(1-prior))
    neg_noise_rate = 0.5 * (1-prior) / (1-prior*(1-prior))  # p(given y=1 | y=-1) = 1/2*(1-prior)/(1-prior*(1-prior))

    def _pruned_mean(losses, keep_ratio):
        # Mean of the `keep_ratio` fraction of smallest per-sample losses. At
        # least one sample is kept: int() can round the count down to 0 on small
        # (e.g. trailing) batches, and the mean of an empty tensor is NaN, which
        # would propagate into the model parameters.
        num_keep = max(1, int(keep_ratio * losses.shape[0]))
        keep_index = torch.argsort(losses.detach())[:num_keep]  # ascending loss
        return losses[keep_index].mean()

    for epoch in range(args.ep):
        model.train()
        for (X, y) in given_train_loader:
            X, y = X.to(device), y.to(device)
            global_step += 1
            optimizer.zero_grad()
            outputs = model(X)[:, 0]
            # NOTE(review): the teacher is refreshed from the student's weights
            # *before* optimizer.step(); mean-teacher code commonly does this after
            # the step. ema_model's train/eval mode is also never set here.
            # Kept as-is -- confirm intent against update_ema.
            update_ema(model, ema_model, ema_decay, global_step)
            with torch.no_grad():
                ema_outputs = ema_model(X)[:, 0]
            pos_index, neg_index = (y == 1), (y == -1)
            pos_train_loss, neg_train_loss = 0.0, 0.0
            if pos_index.any():
                pos_train_loss = 1/(1-pos_noise_rate) * _pruned_mean(
                    loss_fn(outputs[pos_index]), 1-inverse_pos_noise_rate)
            if neg_index.any():
                neg_train_loss = 1/(1-neg_noise_rate) * _pruned_mean(
                    loss_fn(-outputs[neg_index]), 1-inverse_neg_noise_rate)
            label_loss = pos_train_loss + neg_train_loss
            cons_loss = cons_weight * F.mse_loss(outputs, ema_outputs)
            train_loss = label_loss + cons_loss
            train_loss.backward()
            optimizer.step()
        model.eval()
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)
        print('#epoch', epoch+1, ': train_loss', train_loss.data.item(), 'test_accuracy', test_acc)
        if epoch >= (args.ep-10):
            test_acc_list.append(test_acc)
    return np.mean(test_acc_list)

def NoisyUnbiased(model, given_train_loader, test_loader, args, loss_fn, device):
    """Train with the unbiased loss-correction estimator for class-conditional label noise.

    The given Pcomp labels (y == 1 / y == -1) are treated as noisy labels with
    flip rates ``rho_pos = p(given y=-1 | y=1)`` and ``rho_neg = p(given y=1 | y=-1)``,
    both derived from ``args.prior``. With ``l = loss_fn``, each sample's loss is

        given y=1 :  (1 - rho_neg) * l(f)  - rho_pos * l(-f)
        given y=-1:  (1 - rho_pos) * l(-f) - rho_neg * l(f)

    summed over the batch, divided by the batch size and by ``1 - rho_pos - rho_neg``.

    Args:
        model: network whose output column 0 is used as the decision score.
        given_train_loader: iterable of ``(X, y)`` batches with ``y`` in {1, -1}.
        test_loader: evaluation batches passed to ``accuracy_check``.
        args: must provide ``lr``, ``wd``, ``ep`` and the class prior ``prior``.
        loss_fn: margin loss; presumably returns a per-sample tensor.
        device: device the batches are moved to.

    Returns:
        Mean test accuracy over the last 10 epochs (all epochs if ``args.ep`` <= 10).
    """
    test_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', test_acc)
    test_acc_list = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior

    pos_noise_rate = 0.5 * prior / (1-prior*(1-prior))  # p(given y=-1 | y=1) = 1/2*prior/(1-prior*(1-prior))
    neg_noise_rate = 0.5 * (1-prior) / (1-prior*(1-prior))  # p(given y=1 | y=-1) = 1/2*(1-prior)/(1-prior*(1-prior))
    normalizer = 1 - pos_noise_rate - neg_noise_rate

    for epoch in range(args.ep):
        model.train()
        for (X, y) in given_train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(X)[:, 0]
            pos_index, neg_index = (y == 1), (y == -1)
            # Both sums start at zero on every batch. Otherwise a batch holding a
            # single class would hit a NameError (first batch) or reuse the other
            # class's loss from a previous batch, whose graph is already freed.
            pos_loss_sum = outputs.new_zeros(())
            neg_loss_sum = outputs.new_zeros(())
            if pos_index.any():
                pp_train_loss = loss_fn(outputs[pos_index])  # regard the given p data as real p data
                pn_train_loss = loss_fn(-outputs[pos_index])  # regard the given p data as real n data
                pos_loss_sum = ((1-neg_noise_rate)*pp_train_loss - pos_noise_rate*pn_train_loss).sum()
            if neg_index.any():
                nn_train_loss = loss_fn(-outputs[neg_index])  # regard the given n data as real n data
                np_train_loss = loss_fn(outputs[neg_index])  # regard the given n data as real p data
                neg_loss_sum = ((1-pos_noise_rate)*nn_train_loss - neg_noise_rate*np_train_loss).sum()
            train_loss = 1.0/outputs.shape[0] * (pos_loss_sum + neg_loss_sum) / normalizer
            train_loss.backward()
            optimizer.step()
        model.eval()
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)

        print('#epoch', epoch+1, ': train_loss', train_loss.data.item(), 'test_accuracy', test_acc)
        if epoch >= (args.ep-10):
            test_acc_list.append(test_acc)
    return np.mean(test_acc_list)


def PcompReLU(model, given_train_loader, test_loader, args, loss_fn, device):
    """Train with the Pcomp risk estimator, clipping each partial risk at zero (ReLU correction).

    With ``l = loss_fn`` and mini-batch means over given positives (y == 1) and
    given negatives (y == -1), the two partial risks are

        max(0, mean_pos l(f)  - (1-prior) * mean_neg l(f))
        max(0, mean_neg l(-f) - prior     * mean_pos l(-f))

    where terms whose class is absent from the batch are omitted before clipping.

    Args:
        model: network whose output column 0 is used as the decision score.
        given_train_loader: iterable of ``(X, y)`` batches with ``y`` in {1, -1}.
        test_loader: evaluation batches passed to ``accuracy_check``.
        args: must provide ``lr``, ``wd``, ``ep`` and the class prior ``prior``.
        loss_fn: margin loss; presumably returns a per-sample tensor.
        device: device the batches are moved to.

    Returns:
        Mean test accuracy over the last 10 epochs (all epochs if ``args.ep`` <= 10).
    """
    test_acc = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', test_acc)
    test_acc_list = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior

    for epoch in range(args.ep):
        model.train()
        for (X, y) in given_train_loader:
            X, y = X.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(X)[:, 0]
            pos_index, neg_index = (y == 1), (y == -1)
            has_pos, has_neg = bool(pos_index.any()), bool(neg_index.any())
            pos_train_loss, neg_train_loss = 0.0, 0.0
            # torch.clamp(min=0.0) implements max(x, 0). torch.max(tensor, 0.0)
            # is not a valid overload and raises TypeError.
            if has_pos and has_neg:
                pos_train_loss = torch.clamp(
                    loss_fn(outputs[pos_index]).mean() - (1-prior)*loss_fn(outputs[neg_index]).mean(), min=0.0)
                neg_train_loss = torch.clamp(
                    loss_fn(-outputs[neg_index]).mean() - prior*loss_fn(-outputs[pos_index]).mean(), min=0.0)
            elif has_pos:
                pos_train_loss = torch.clamp(loss_fn(outputs[pos_index]).mean(), min=0.0)
                neg_train_loss = torch.clamp(-prior*loss_fn(-outputs[pos_index]).mean(), min=0.0)
            elif has_neg:
                pos_train_loss = torch.clamp(-(1-prior)*loss_fn(outputs[neg_index]).mean(), min=0.0)
                neg_train_loss = torch.clamp(loss_fn(-outputs[neg_index]).mean(), min=0.0)
            train_loss = pos_train_loss + neg_train_loss
            train_loss.backward()
            optimizer.step()
        model.eval()
        test_acc = accuracy_check(loader=test_loader, model=model, device=device)

        print('#epoch', epoch+1, ': train_loss', train_loss.data.item(), 'test_accuracy', test_acc)
        if epoch >= (args.ep-10):
            test_acc_list.append(test_acc)
    return np.mean(test_acc_list)
def PcompABS(model, given_train_loader, test_loader, args, loss_fn, device):
    """Train with the Pcomp risk estimator, taking the absolute value of each partial risk.

    With ``l = loss_fn`` and mini-batch means over given positives (y == 1) and
    given negatives (y == -1), the two partial risks are

        |mean_pos l(f)  - (1-prior) * mean_neg l(f)|
        |mean_neg l(-f) - prior     * mean_pos l(-f)|

    where terms whose class is absent from the batch are omitted before the
    absolute value is taken.

    Args:
        model: network whose output column 0 is used as the decision score.
        given_train_loader: iterable of ``(X, y)`` batches with ``y`` in {1, -1}.
        test_loader: evaluation batches passed to ``accuracy_check``.
        args: must provide ``lr``, ``wd``, ``ep`` and the class prior ``prior``.
        loss_fn: margin loss; presumably returns a per-sample tensor.
        device: device the batches are moved to.

    Returns:
        Mean test accuracy over the last 10 epochs (all epochs if ``args.ep`` <= 10).
    """
    accuracy = accuracy_check(loader=test_loader, model=model, device=device)
    print('#epoch 0', ': test_accuracy', accuracy)
    tail_accuracies = []
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.wd)
    prior = args.prior

    for epoch in range(args.ep):
        model.train()
        for inputs, targets in given_train_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            scores = model(inputs)[:, 0]
            pos_mask = targets == 1
            neg_mask = targets == -1
            has_pos = bool(pos_mask.any())
            has_neg = bool(neg_mask.any())

            if has_pos:
                pos_scores = scores[pos_mask]
                pos_as_pos = loss_fn(pos_scores).mean()   # l(f, +1) on given positives
                pos_as_neg = loss_fn(-pos_scores).mean()  # l(f, -1) on given positives
            if has_neg:
                neg_scores = scores[neg_mask]
                neg_as_pos = loss_fn(neg_scores).mean()   # l(f, +1) on given negatives
                neg_as_neg = loss_fn(-neg_scores).mean()  # l(f, -1) on given negatives

            if has_pos and has_neg:
                pos_risk = torch.abs(pos_as_pos - (1-prior)*neg_as_pos)
                neg_risk = torch.abs(neg_as_neg - prior*pos_as_neg)
            elif has_pos:
                pos_risk = torch.abs(pos_as_pos)
                neg_risk = torch.abs(-prior*pos_as_neg)
            elif has_neg:
                pos_risk = torch.abs(-(1-prior)*neg_as_pos)
                neg_risk = torch.abs(neg_as_neg)
            else:
                pos_risk, neg_risk = 0.0, 0.0

            train_loss = pos_risk + neg_risk
            train_loss.backward()
            optimizer.step()

        model.eval()
        accuracy = accuracy_check(loader=test_loader, model=model, device=device)

        print('#epoch', epoch + 1, ': train_loss', train_loss.item(), 'test_accuracy', accuracy)
        if args.ep - epoch <= 10:
            tail_accuracies.append(accuracy)
    return np.mean(tail_accuracies)

def accuracy_check(loader, model, device):
    """Return the fraction of samples whose sign prediction matches their label.

    A sample is predicted +1 when ``model(images)[:, 0] >= 0`` and -1 otherwise,
    so ``labels`` are expected to be encoded as +1 / -1.

    The model is run in eval mode, so dropout is disabled and BatchNorm does not
    update its running statistics from evaluation data. Each submodule's
    train/eval flag is restored afterwards, including on error.

    Args:
        loader: iterable of ``(images, labels)`` batches.
        model: ``torch.nn.Module``; output column 0 is the decision score.
        device: device the batch tensors are moved to.

    Returns:
        Accuracy as a Python float in [0, 1].

    Raises:
        ZeroDivisionError: if ``loader`` yields no samples.
    """
    saved_modes = [(module, module.training) for module in model.modules()]
    model.eval()
    try:
        with torch.no_grad():
            total, num_samples = 0, 0
            for images, labels in loader:
                labels, images = labels.to(device), images.to(device)
                outputs = model(images)[:, 0]
                predicted = (outputs >= 0).float() * 2 - 1  # True -> +1.0, False -> -1.0
                total += (predicted == labels).sum().item()
                num_samples += labels.size(0)
    finally:
        for module, mode in saved_modes:
            module.training = mode
    return total / num_samples